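'''
Split the data into separate classes and tally how many records fall into
each class (raw per-class counts are printed; shares are not computed here).
'''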

import os
import json

from tqdm import tqdm
from count_chi import getClass
from collections import defaultdict

from count_chi import INPUT_FILE, OUTPUT_DIR

# Echo the configured paths (imported from count_chi) so each run records
# which input file was read and where the split files were written.
print(f'{INPUT_FILE = }')
print(f'{OUTPUT_DIR = }')

# Create the output directory up front; an already-existing one is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)

def main():
    '''
    Split INPUT_FILE (one JSON object per line) into per-class JSONL files.

    Every non-blank line is parsed with json.loads and classified with
    getClass. The original line text is then written unchanged to
    OUTPUT_DIR/split-<label>.jsonl; that file is opened in 'w' mode (so any
    existing content is truncated) the first time its label is seen.
    Once the whole input has been consumed, the per-label record counts are
    printed.

    All output files are closed even if processing fails part-way through.

    Raises:
        OSError: if the input file or an output file cannot be opened.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    '''
    writerDict = dict()            # label -> open output file handle
    countDict = defaultdict(int)   # label -> number of records seen
    try:
        with open(INPUT_FILE, 'r', encoding = 'utf-8') as file:
            for line in tqdm(file):
                # Tolerate blank lines (e.g. a trailing empty line at EOF)
                # instead of letting json.loads fail on them.
                if not line.strip():
                    continue
                obj = json.loads(line)
                label = getClass(obj)
                countDict[label] += 1
                writer = writerDict.get(label)
                if writer is None:
                    # Open each label's output file lazily, on first use.
                    writer = open(os.path.join(OUTPUT_DIR, f'split-{label}.jsonl'), 'w', encoding = 'utf-8')
                    writerDict[label] = writer
                # Write the raw line (newline included) so records are
                # copied verbatim rather than re-serialized.
                writer.write(line)
    finally:
        # Close every writer even when an exception interrupts the loop,
        # so buffered data is flushed and no file handle is leaked.
        for writer in writerDict.values():
            writer.close()
    print(f'{countDict = }')

# Run the split only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == '__main__':
    main()
